#! /usr/bin/python3

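# Vim usage: add the command below to your vimrc, then run :F on an RTL buffer;
# the buffer is replaced by this script's stdout (the file with its assertion
# block regenerated).
#command! F :execute '%! /your_path/auto_assert.py %'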

import os
import time
import sys
import re
import argparse

def find_signals(line): #{{{
  """Return the identifiers used in a Verilog expression as "a, b, c".

  Based literals (8'hFF, 'b0, 4'sd3, 8'd 255), unbased fill literals
  ('0, '1, 'x, 'z) and system-function names ($past, $rose, ...) are
  removed first so their letters are not mistaken for signal names.
  Names are returned in order of first appearance, without duplicates;
  an expression with no identifiers yields "".
  """
  # Base letter may be upper or lower case, optionally signed, and Verilog
  # allows whitespace between the base and the digits.
  line = re.sub(r"'(?:[sS]?[bBoOdDhH]\s*[0-9a-zA-Z_?]+|[01xXzZ]\b)", " ", line)
  line = re.sub(r"\$\w+", " ", line)
  signals = re.findall(r"\b[a-zA-Z_]\w*\b", line)
  return ", ".join(dict.fromkeys(signals))
#}}}

def _range_width(range_text):
  """Return the bit width described by a packed range such as "[7:0]".

  "[7:0]" / "[0:7]" / "[15:8]" -> int width; "[WIDTH-1:0]" -> "WIDTH";
  "[MSB:0]" -> "MSB+1"; any other range is returned as its text with the
  brackets removed.
  """
  text = range_text.replace("]", "").replace("[", "").strip()
  numeric = re.match(r"(\d+)\s*:\s*(\d+)$", text)
  if numeric:
    return abs(int(numeric.group(1)) - int(numeric.group(2))) + 1
  minus_one = re.match(r"(\w+)\s*-\s*1\s*:\s*0$", text)
  if minus_one:
    return minus_one.group(1)
  symbolic = re.match(r"(\w+)\s*:\s*0$", text)
  if symbolic:
    # Symbolic MSB (e.g. a parameter): int() would fail, keep an expression.
    return symbolic.group(1) + "+1"
  return text

def sys_rtl_file(file):
  """Scan an RTL file for signal declarations and assertion clock/reset names.

  Sets the module globals assert_clk_name / assert_rst_name to "clk" /
  "rst_n", overridden by any "wire assert_clk = X;" /
  "wire assert_rst_n = X;" line in the file.

  Args:
    file: path of the RTL source to read.

  Returns:
    dict mapping signal name -> info dict with keys:
      'type'  -- second keyword (wire/reg/logic) when the line has two,
                 otherwise the single leading keyword;
      'port'  -- first keyword when the line has two keywords, else "none";
      'width' -- 1 without a range, otherwise the result of _range_width;
      'note'  -- trailing // comment with every "//" removed (only present
                 when the line has such a comment).
  """
  global assert_clk_name
  global assert_rst_name
  assert_clk_name = "clk"
  assert_rst_name = "rst_n"
  signal_info = {}

  # \b after each keyword keeps names like reg_a / wire_en / register_x from
  # being split into keyword + remainder; signed/unsigned is skipped so it is
  # not taken as the signal name.
  decl_re = re.compile(
    r"^\s*(input|output|inout|wire|reg|logic)\b\s*"
    r"(?:(wire|reg|logic)\b)?\s*"
    r"(?:(?:signed|unsigned)\b)?\s*"
    r"(\[[\*\+\/\s\-\:\w]+\])?\s*"
    r"(\w+)[\;\,\s]+(//.*)?")
  clk_re = re.compile(r"^\s*wire\s+assert_clk\s*=\s*(\w+)\s*;")
  rst_re = re.compile(r"^\s*wire\s+assert_rst_n\s*=\s*(\w+)\s*;")

  with open(file, "r") as rtl:
    for line in rtl:
      result = decl_re.search(line)
      if result:
        info = {}
        if result.group(2):
          info['type'] = result.group(2)
          info['port'] = result.group(1)
        else:
          info['type'] = result.group(1)
          info['port'] = "none"
        info['width'] = _range_width(result.group(3)) if result.group(3) else 1
        if result.group(5):
          info['note'] = result.group(5).replace("//", "")
        signal_info[result.group(4)] = info
      re_clk = clk_re.search(line)
      if re_clk:
        assert_clk_name = re_clk.group(1)
      re_rst = rst_re.search(line)
      if re_rst:
        assert_rst_name = re_rst.group(1)
  return signal_info

def add_chk_xz(signal, note):
  """Build the X/Z assertion for *signal* from a "chk_xz" note.

    "chk_xz"              -> chk_xz(signal), checked every clock
    "chk_xz while <cond>" -> chk_xz_valid(signal, <cond>), checked only
                             while <cond> holds; any whitespace (or none,
                             before a parenthesis) may surround "while".

  Returns the labelled assert statement as a string, or -1 when *note* is
  not an X/Z check (the sentinel tested by the caller).

  NOTE(review): the fail action is only $assertoff; in SystemVerilog an
  explicit fail action replaces the default $error, so confirm the
  simulator still reports the failure.
  """
  if not re.match(r"chk_xz", note):
    return -1
  assert_inst = "chk_%s_xz_assert" % signal
  guarded = re.match(r"chk_xz\s+while\b\s*(\S.*)", note)
  if guarded:
    prop = "chk_xz_valid(%s, %s)" % (signal, guarded.group(1).strip())
  else:
    prop = "chk_xz(%s)" % signal
  return "%s: assert property (%s) else $assertoff(0, %s);" % (assert_inst, prop, assert_inst)

def _split_value_note(note):
  """Split an "== <values> [while <cond>]" note into (values, cond).

  The leading "==" and one surrounding pair of brackets are removed from the
  values; cond is "" when there is no (non-empty) while clause.
  """
  rest = re.sub(r"^==\s*", "", note)
  guarded = re.match(r"(.*?)\s*\bwhile\b\s*(.*)$", rest)
  if guarded:
    values, cond = guarded.group(1), guarded.group(2).strip()
  else:
    values, cond = rest, ""
  bracketed = re.match(r"\s*\[(.*)\]", values)
  if bracketed:
    values = bracketed.group(1)
  return values, cond

def _value_term(signal, item):
  """Return the SV expression matching *signal* against one stripped note item.

  "v" -> (signal == v).  "lo~hi" -> inclusive range split at the last "~";
  "$" before the "~" leaves the bottom open, "$" or nothing after it leaves
  the top open.  An item with nothing before its "~" (e.g. "~MASK") is a
  bitwise NOT and is compared as a plain value.
  """
  if "~" not in item:
    return "(%s == %s)" % (signal, item)
  low, high = (bound.strip() for bound in item.rsplit("~", 1))
  if not low:
    return "(%s == %s)" % (signal, item)
  bounds = []
  if low != "$":
    bounds.append("(%s >= %s)" % (signal, low))
  if high not in ("", "$"):
    bounds.append("(%s <= %s)" % (signal, high))
  if not bounds:
    return "1'b1"
  if len(bounds) == 1:
    return bounds[0]
  return "(%s && %s)" % (bounds[0], bounds[1])

def add_chk_value(signal, note):
  """Build a property + assert restricting *signal* to the values in *note*.

  Note syntax (leading "==" required, brackets optional):
    == [v0, lo~hi, lo~$]           checked on every clock
    == [v0, lo~hi] while <cond>    checked only while <cond> holds
  The comma-separated items are OR-ed together (see _value_term).

  Returns the SystemVerilog text (property followed by the labelled assert,
  joined by newlines, no trailing newline), or -1 when *note* is not a value
  check or lists no values.

  NOTE(review): the fail action is only $assertoff; in SystemVerilog an
  explicit fail action replaces the default $error, so confirm the
  simulator still reports the failure.
  """
  if not re.match(r"==", note):
    return -1
  values, cond = _split_value_note(note)
  terms = [_value_term(signal, item.strip()) for item in values.split(",") if item.strip()]
  if not terms:
    return -1
  assert_body = " || ".join(terms)

  if cond:
    # Identifiers in the condition become formal arguments of the property;
    # skip the checked signal so it is not declared twice.
    cond_signals = [name for name in find_signals(cond).split(", ") if name and name != signal]
    args = ", ".join(cond_signals + [signal])
    assert_inst = "chk_%s_valid_value_assert" % signal
    lines = [
      "property chk_%s_valid_value(%s);" % (signal, args),
      "  @(posedge assert_clk) disable iff(~after_rst)",
      "  %s |-> %s;" % (cond, assert_body),
      "endproperty",
      "%s: assert property (chk_%s_valid_value(%s)) else $assertoff(0, %s);" % (assert_inst, signal, args, assert_inst),
    ]
  else:
    assert_inst = "chk_%s_value_assert" % signal
    lines = [
      "property chk_%s_value(%s);" % (signal, signal),
      "  @(posedge assert_clk) disable iff(~after_rst)",
      "  %s;" % assert_body,
      "endproperty",
      "%s: assert property (chk_%s_value(%s)) else $assertoff(0, %s);" % (assert_inst, signal, signal, assert_inst),
    ]
  return "\n".join(lines)

#def add_chk_end(signal, note):
#  if not re.match(r"END =", note):
#    return -1
#  note = re.sub(r"END =\s+", "", note)
#  assert_body = signal + " == " + note
#
#  assert_inst = "chk_%s_end_assert" % signal
#  str = "property chk_%s_end(%s);\n" % (signal, signal)
#  str += "  @(posedge assert_clk) disable iff(~after_rst | harness.end_of_check == 1'b0)\n"
#  str += "  %s;\n" % (assert_body)
#  str += "endproperty\n"
#  #str += "assert property (chk_%s_end(%s)) else $error(\"%s has end error\", %s);\n" % (signal, signal, r"%s", signal)
#  str += "%s: assert property (chk_%s_end(%s)) else $assertoff(0, %s);\n" % (assert_inst, signal, signal, assert_inst)
#  str += "\n"
#  return str

def add_chk_end(signal, note):
  """Build an end-of-simulation value check for *signal* from an END note.

  Accepts "END = <expr>" with any spacing and "=", "==" or "===" as the
  operator.  Returns one indented line, intended for a final block, that
  calls $error when the signal is not case-equal (!==) to <expr>.  Returns
  -1 when *note* is not an END check or has no expression.
  """
  parsed = re.match(r"END\s*={1,3}\s*(\S.*)", note)
  if not parsed:
    return -1
  assert_body = "%s !== %s" % (signal, parsed.group(1).strip())
  return "  if({0}) $error(\"Error in end check: {0}\");".format(assert_body)

def gen_assert_code(info):
  """Return the generated assertion section as a list of text lines.

  Reads the module globals assert_key_word (set by main) and
  assert_clk_name / assert_rst_name (set by sys_rtl_file).

  Args:
    info: signal name -> info dict as produced by sys_rtl_file.  Only the
      optional 'note' entry is used: a note containing
      "ASSERT: item; item; ..." requests checks.  Each non-empty item is
      offered to add_chk_xz, add_chk_value and add_chk_end, and every
      result other than -1 is emitted.

  Returns:
    Lines without trailing newlines, framed by
    "`ifdef <key> // AUTO ADD {{{" and "`endif //AUTO ADD <key> }}}".
  """
  assert_list = [
    "`ifdef {0} // AUTO ADD {{{{{{".format(assert_key_word),
    "",
    "wire assert_clk   = {0};".format(assert_clk_name),
    "wire assert_rst_n = {0};".format(assert_rst_name),
    "",
    # after_rst gates every check: a single process clears it while reset is
    # asserted and sets it 10 clocks after each reset release.  (Setting it
    # to 1 during reset would enable the checks while the design is in reset.)
    "reg after_rst = 1'b0;",
    "initial begin",
    "  forever begin",
    "    wait(assert_rst_n == 1'b0);",
    "    after_rst = 1'b0;",
    "    wait(assert_rst_n == 1'b1);",
    "    repeat(10) @(posedge assert_clk);",
    "    after_rst = 1'b1;",
    "  end",
    "end",
    "",
    "property chk_xz(info);",
    "  @(posedge assert_clk) disable iff(~after_rst)",
    "  ~$isunknown(info);",
    "endproperty",
    "",
    "property chk_xz_valid(info, valid);",
    "  @(posedge assert_clk) disable iff(~after_rst)",
    "  valid |-> ~$isunknown(info);",
    "endproperty",
    "",
  ]

  xz_checks = []
  value_checks = []
  end_checks = []
  for signal, sig_info in info.items():
    note = sig_info.get('note')
    if note is None:
      continue
    result = re.search(r"ASSERT:\s*([^/]+)", note)
    if not result:
      continue
    for item in result.group(1).split(";"):
      item = item.strip()
      if not item:
        continue
      # Each builder returns -1 when the item is not its kind of check.
      for builder, bucket in ((add_chk_xz, xz_checks),
                              (add_chk_value, value_checks),
                              (add_chk_end, end_checks)):
        code = builder(signal, item)
        if code != -1:
          bucket.append(code)

  assert_list += ["", "//CHECK SIGNAL X and Z"] + xz_checks
  assert_list += ["", "//CHECK SIGNAL VALUE"] + value_checks
  assert_list += ["", "//CHECK SIGNAL END VALUE"]
  if end_checks:
    assert_list += ["final begin"] + end_checks + ["end"]
  assert_list += ["", "`endif //AUTO ADD {0} }}}}}}".format(assert_key_word)]
  return assert_list

def write_rtl(file, assert_list):
  """Print *file* to stdout with its auto-generated assertion section replaced.

  An existing section is replaced in place by *assert_list*. It runs from a
  line starting with "`ifdef <key> " through a line starting with
  "`endif //AUTO ADD <key>". Without one, *assert_list* is inserted just
  before the first (possibly indented) endmodule line. Any further sections
  with the same markers are dropped, so the assertions appear only once.
  Output lines are right-stripped.

  If a begin marker has no matching end marker, the file is printed
  unchanged rather than discarding everything after the marker.

  Reads the module global assert_key_word (set by main).
  """
  key = re.escape(assert_key_word)
  begin_re = re.compile(r"`ifdef {0} ".format(key))
  end_re = re.compile(r"`endif //AUTO ADD {0}".format(key))
  endmodule_re = re.compile(r"\s*endmodule\b")

  with open(file, "r") as rtl:
    raw_lines = rtl.readlines()

  new_rtl = []
  in_old_block = False
  has_add = False
  for line in raw_lines:
    if in_old_block:
      # Skip the previous output up to and including its end marker.
      if end_re.match(line):
        in_old_block = False
      continue
    if begin_re.match(line):
      in_old_block = True
      if not has_add:
        new_rtl.extend(assert_list)
        has_add = True
      continue
    if not has_add and endmodule_re.match(line):
      new_rtl.extend(assert_list)
      has_add = True
    new_rtl.append(line.rstrip())

  if in_old_block:
    # Unterminated auto block: keep the file intact instead of losing its tail.
    new_rtl = [line.rstrip() for line in raw_lines]

  for line in new_rtl:
    print(line)

def main(argv=None):
  """Command-line entry point: print RTL_FILE with its assertions regenerated.

  Usage: auto_assert.py RTL_FILE [KEY_WORD]
  KEY_WORD (default AUTO_ASSERT_ON) names the `ifdef macro that guards the
  generated section and is stored in the module global assert_key_word.

  Args:
    argv: argument list to parse; None means sys.argv[1:].
  """
  global assert_key_word
  parser = argparse.ArgumentParser(
    description="Regenerate the auto assertion block of an RTL file and print the result.")
  parser.add_argument("rtl_file", help="RTL source file to scan and rewrite")
  parser.add_argument("key_word", nargs="?", default="AUTO_ASSERT_ON",
                      help="`ifdef macro guarding the generated block (default: %(default)s)")
  args = parser.parse_args(argv)
  assert_key_word = args.key_word

  signal_info = sys_rtl_file(args.rtl_file)
  assert_list = gen_assert_code(signal_info)
  write_rtl(args.rtl_file, assert_list)

if __name__ == '__main__':
  main()
